{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import sys\n",
    "import time\n",
    "import _init_paths\n",
    "from core.config import cfg, merge_cfg_from_file, merge_cfg_from_list, assert_and_infer_cfg\n",
    "from core.test_engine import run_inference, get_inference_dataset\n",
    "from datasets.json_dataset import JsonDataset\n",
    "import pickle\n",
    "import json\n",
    "from utils.boxes import xywh_to_xyxy, bbox_overlaps\n",
    "import ipdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg.TEST.DATASETS = ('vcoco_test',)\n",
    "cfg.MODEL.NUM_CLASSES = 81\n",
    "cfg.TEST.PRECOMPUTED_PROPOSALS = False\n",
    "cfg.MODEL.VCOCO_ON = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "ican_res_file = '../ican/ican_vcoco_det_res.json'\n",
    "vcoco_det_res = json.load(open(ican_res_file, 'r'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "file_path = '../Outputs/e2e_faster_rcnn_R-50-FPN_1x_vcoco/Nov26-09-29-09_wanbo_node35_step/test/bbox_vcoco_test_results.json'\n",
    "# file_path = '../data/test/bbox_vcoco_test_results.json'\n",
    "vcoco_det_res = json.load(open(file_path, 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading annotations into memory...\n",
      "Done (t=1.26s)\n",
      "creating index...\n",
      "index created!\n",
      "loading vcoco annotations...\n"
     ]
    }
   ],
   "source": [
    "dataset_name, proposal_file = get_inference_dataset(0, is_parent=False)\n",
    "dataset = JsonDataset(dataset_name)\n",
    "vcocodb = dataset.get_roidb(gt=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "119061"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(vcoco_det_res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}\n",
    "for det in vcoco_det_res:\n",
    "    im_id = det['image_id']\n",
    "    tmp = results.setdefault(im_id, [])\n",
    "    tmp.append(np.concatenate([det['bbox'], [det['score'], det['category_id']]]))\n",
    "    results[im_id] = tmp\n",
    "\n",
    "for k,v in results.items():\n",
    "    v = np.array(v).astype(np.float32)\n",
    "    trans = xywh_to_xyxy(v[:, :4])\n",
    "    v[:, :4] = trans\n",
    "    results[k] = v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "no_gt_image_num = 0\n",
    "no_im_id = 0\n",
    "total_num = [0, 0]\n",
    "recall_num = [0, 0]\n",
    "\n",
    "for gt in vcocodb:\n",
    "#     gt = vcocodb[1]\n",
    "    im_id = gt['id']\n",
    "\n",
    "    gt_boxes = gt['boxes'].astype(np.float32)\n",
    "    gt_cls = gt['gt_classes']\n",
    "    gt_actions = gt['gt_actions']\n",
    "    gt_role_inds = gt['gt_role_id']\n",
    "    gt_human_inds, gt_act_inds, gt_r_ids = np.where(gt_role_inds>-1)\n",
    "\n",
    "    if len(gt_human_inds) == 0:\n",
    "        no_gt_image_num += 1\n",
    "        continue\n",
    "\n",
    "    gt_obj_inds = gt_role_inds[gt_human_inds, gt_act_inds, gt_r_ids]\n",
    "    gt_human_inds = np.unique(gt_human_inds)\n",
    "    gt_obj_inds = np.unique(gt_obj_inds)\n",
    "    gt_human_box = gt_boxes[gt_human_inds]\n",
    "    gt_obj_box = gt_boxes[gt_obj_inds]\n",
    "\n",
    "    total_num[0] += len(gt_human_box)\n",
    "    total_num[1] += len(gt_obj_box)\n",
    "\n",
    "    if results.get(im_id) is None:\n",
    "        no_im_id += 1\n",
    "        continue\n",
    "    pred_res = results[im_id]\n",
    "    pred_boxes = pred_res[:, :4]\n",
    "    pred_scores = pred_res[:, -2]\n",
    "    pred_cls = pred_res[:, -1].astype(np.int)\n",
    "\n",
    "    pred_obj_ind = np.where(pred_cls>1)\n",
    "    if len(pred_obj_ind[0]) == 0:\n",
    "        continue\n",
    "    pred_obj_box = pred_boxes[pred_obj_ind]\n",
    "    pred_obj_score = pred_scores[pred_obj_ind]\n",
    "\n",
    "    valid = np.where(pred_obj_score>=0.4)\n",
    "    if len(valid[0]) > 0:\n",
    "        valid_obj_box = pred_obj_box[valid]\n",
    "        obj_overlap = bbox_overlaps(gt_obj_box, valid_obj_box)\n",
    "    #     obj_overlap = bbox_overlaps(gt_obj_box, pred_obj_box)\n",
    "        try:\n",
    "            obj_max_overlap = obj_overlap.max(1)\n",
    "        except:\n",
    "            ipdb.set_trace()\n",
    "        recall_num[1] += np.sum(obj_max_overlap>=0.5)\n",
    "\n",
    "    pred_human_ind = np.where(pred_cls==1)\n",
    "    if len(pred_human_ind[0]) == 0:\n",
    "        continue\n",
    "    pred_human_box = pred_boxes[pred_human_ind]\n",
    "    pred_human_scores = pred_scores[pred_human_ind]\n",
    "    valid = np.where(pred_human_scores>=0.5)\n",
    "    if len(valid[0]) > 0:\n",
    "        valid_human_box = pred_human_box[valid]\n",
    "        human_overlap = bbox_overlaps(gt_human_box, valid_human_box)\n",
    "    #     human_overlap = bbox_overlaps(gt_human_box, pred_human_box)\n",
    "        human_max_overlap = human_overlap.max(1)\n",
    "        recall_num[0] += np.sum(human_max_overlap>=0.5)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[6287, 7020]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_num"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[6106, 5699]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "recall_num"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8118233618233618"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "5699/7020"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9712104342293622"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "6106/6287"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "407"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "no_gt_image_num"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "numpy.ndarray"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(results[165])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path = '../data/cache/vcoco_test_precomp_boxes.json'\n",
    "for k,v in results.items():\n",
    "    results[k] = v.tolist()\n",
    "    \n",
    "with open(save_path, 'w') as f:\n",
    "    json.dump(results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8358314479638009"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "5911/7072"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4946"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
